package dao

import (
	"context"
	"github.com/ecodeclub/ekit/sqlx"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	"time"
)

var ErrWaitingSMSNotFound = gorm.ErrRecordNotFound

//go:generate mockgen -source=./async_sms.go -package=daomocks -destination=mocks/async_sms_mock.go AsyncSmsDAO
type AsyncSmsDAO interface {
	// Insert 插入一条异步短信
	Insert(ctx context.Context, sms AsyncSms) error
	// GetWaitingSMS 获取一条待发送的短信
	GetWaitingSMS(ctx context.Context) (AsyncSms, error)
	// MarkSuccess 标记指定短信发送状态为成功
	MarkSuccess(ctx context.Context, id int64) error
	// MarkFailed 标记指定短信发送状态为失败（超过了重试次数，才会标记失败）
	MarkFailed(ctx context.Context, id int64) error
}

type AsyncSms struct {
	Id       int64
	Config   sqlx.EncryptColumn[SmsConfig] // 加密短信参数
	RetryCnt int                           // 重试次数
	RetryMax int                           // 最大重试次数
	Status   uint8                         // 发送状态
	Ctime    int64
	Utime    int64 `gorm:"index"`
}

type SmsConfig struct {
	TplID   string
	Args    []string
	Numbers []string
}

const ( // 发送状态
	asyncStatusWaiting = iota // 等待发送
	asyncStatusFailed         // 发送成功
	asyncStatusSuccess        // 发送失败
)

type GORMAsyncSmsDAO struct {
	db *gorm.DB
}

func NewGORMAsyncSmsDAO(db *gorm.DB) *GORMUserDAO {
	return &GORMUserDAO{db: db}
}

// Insert 插入一条异步短信
func (g *GORMAsyncSmsDAO) Insert(ctx context.Context, sms AsyncSms) error {
	return g.db.Create(&sms).Error
}

// GetWaitingSMS 获取一条待发送的短信
// 如果在高并发情况下, SELECT for UPDATE 对数据库的压力很大,但是我们不是高并发，
// 因为你部署N台机器，才有 N 个goroutine 来查询，并发不过百，随便写
func (g *GORMAsyncSmsDAO) GetWaitingSMS(ctx context.Context) (AsyncSms, error) {
	var asyncSms AsyncSms
	// 开启事务
	err := g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// 1. 取出一条待发送的异步短信
		// 为了避开一些偶发性的失败，我们只处理一分钟前的异步短信
		now := time.Now().UnixMilli()
		endTime := now - time.Minute.Milliseconds()
		// SELECT FOR UPDATE 上行级排他锁，避免其他事务修改该记录
		err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
			Where("utime<? and status=?", endTime, asyncStatusWaiting).
			First(&asyncSms).Error
		if err != nil {
			return err
		}
		// 2. 更新该记录的 utime
		err = tx.Model(&AsyncSms{}).Where("id=?", asyncSms.Id).
			Updates(map[string]any{
				"retry_cnt": gorm.Expr("retry_cnt + 1"),
				"utime":     now,
			}).Error
		return err
	})
	return asyncSms, err
}

// MarkSuccess 标记指定短信发送状态为成功
func (g *GORMAsyncSmsDAO) MarkSuccess(ctx context.Context, id int64) error {
	now := time.Now().UnixMilli()
	return g.db.WithContext(ctx).Model(&AsyncSms{}).
		Where("id=?", id).
		Updates(map[string]any{
			"utime":  now,
			"status": asyncStatusSuccess,
		}).Error
}

// MarkFailed 标记指定短信发送状态为失败（超过了重试次数，才会标记失败）
func (g *GORMAsyncSmsDAO) MarkFailed(ctx context.Context, id int64) error {
	now := time.Now().UnixMilli()
	return g.db.WithContext(ctx).Model(&AsyncSms{}).
		Where("id=? and `retry_cnt`>=`retry_max` ", id).
		Updates(map[string]any{
			"utime":  now,
			"status": asyncStatusFailed,
		}).Error
}
